from buffer import ReplayBuffer
import jax
import numpy as np
import haiku as hk
import functools
from absl import app
from absl import flags
from jax.config import config

from mdp import Wrapper, build_MDP
from rl import FRDQN, DQN, act, init_agent_state, update, update_target
from metrics import get_true_error

FLAGS = flags.FLAGS

# Command-line flags consumed by main().
flags.DEFINE_integer('seed', 42, 'Random seed for the run.')
flags.DEFINE_integer('use_target_net', 0,
                     'If nonzero, train with the DQN loss; if 0, use FRDQN.')
flags.DEFINE_integer('batch_size', 512,
                     'Number of transitions sampled from the replay buffer '
                     'per update step.')
flags.DEFINE_integer('target_update_freq', 10,
                     'Call update_target once every this many update '
                     'iterations.')
flags.DEFINE_float('reg_weight', 0.4,
                   'Weight passed as the reg_weight argument of each update '
                   'step.')
# NOTE(review): tau presumably is a soft-update (Polyak) coefficient where 1.0
# means a hard copy of the online network -- confirm in rl.update_target.
flags.DEFINE_float('tau', 1., 'Value passed as tau to update_target.')

def main(argv):
    """Train an agent on the wrapped MDP and periodically evaluate it.

    The loss is chosen by ``--use_target_net``: ``DQN`` when nonzero, ``FRDQN``
    otherwise. Experience is collected with epsilon-greedy exploration into a
    replay buffer. Once the buffer can be sampled, one jitted ``update`` step
    runs per environment step. Every ``eval_freq`` updates, ``get_true_error``
    and the mean evaluation return on a separate test environment are logged,
    together with the return of the last completed training episode.

    Args:
      argv: Remaining command-line arguments forwarded by ``app.run``; unused.
    """
    # Local import: the file only imports `app` and `flags` from absl.
    from absl import logging

    del argv  # Unused.

    num_runs = 1
    max_iterations = 200_000  # No new episode starts after this many updates.
    replay_capacity = 1_000_000
    epsilon = 0.05  # Probability of taking a uniformly random action.
    # NOTE(review): hard-coded action count; confirm it matches the env.
    num_actions = 4
    eval_freq = 1_000  # Evaluate every this many update iterations.
    num_eval_episodes = 10

    rng_seq = hk.PRNGSequence(FLAGS.seed)
    # hk.PRNGSequence only drives JAX randomness. Seed NumPy's global RNG as
    # well so the epsilon-greedy exploration below is reproducible under --seed.
    np.random.seed(FLAGS.seed)

    loss_fn = DQN if FLAGS.use_target_net else FRDQN
    # Bind the loss once and jit the result. A distinct name avoids shadowing
    # the module-level `update` import.
    update_fn = jax.jit(functools.partial(update, loss_fn=loss_fn))

    env = Wrapper(build_MDP)
    test_env = Wrapper(build_MDP)

    def _avg_eval_return(agent_state, eval_env):
        """Mean undiscounted return over `num_eval_episodes` episodes.

        Actions come from `act` with no epsilon exploration.
        """
        returns = []
        for _ in range(num_eval_episodes):
            state = eval_env.reset()
            done = False
            episode_return = 0
            while not done:
                action = act(agent_state, state)
                state, reward, done, _ = eval_env.step(int(action))
                episode_return += reward
            returns.append(episode_return)
        return np.mean(returns)

    for run in range(num_runs):
        buffer = ReplayBuffer(replay_capacity)
        agent_state = None
        iteration = 0
        last_episode_return = None
        while iteration <= max_iterations:
            state = env.reset()
            if agent_state is None:
                # init_agent_state takes an observation, so initialization
                # waits for the first reset.
                agent_state = init_agent_state(next(rng_seq), state)
            done = False
            episode_return = 0
            while not done:
                if np.random.random() <= epsilon:
                    action = np.random.randint(num_actions)
                else:
                    action = act(agent_state, state)
                action = int(action)
                next_state, reward, done, _ = env.step(action)
                episode_return += reward
                buffer.push(state, action, reward, next_state, done)
                state = next_state

                # `iteration` counts update steps, not env steps: it only
                # advances once the buffer can be sampled.
                if not buffer.can_sample():
                    continue
                batch = buffer.sample(FLAGS.batch_size)
                agent_state, _ = update_fn(agent_state, batch, FLAGS.reg_weight)

                if iteration % FLAGS.target_update_freq == 0:
                    agent_state = update_target(agent_state, FLAGS.tau)

                if iteration % eval_freq == 0:
                    true_error = get_true_error(agent_state)
                    eval_return = _avg_eval_return(agent_state, test_env)
                    logging.info(
                        'run=%d iteration=%d eval_return=%s '
                        'last_train_episode_return=%s true_error=%s',
                        run, iteration, eval_return, last_episode_return,
                        true_error)

                iteration += 1
            last_episode_return = episode_return
    
    
if __name__ == '__main__':
    # JAX settings applied before absl parses flags: GPU platform by default,
    # and implicit rank promotion in jax.numpy turned into an error.
    # NOTE(review): config_with_absl() is called after these updates, presumably
    # so they act as overridable flag defaults (e.g. --jax_platform_name=cpu);
    # confirm against the JAX version in use.
    _jax_overrides = (
        ('jax_platform_name', 'gpu'),
        ('jax_numpy_rank_promotion', 'raise'),
    )
    for _option, _value in _jax_overrides:
        config.update(_option, _value)
    config.config_with_absl()
    app.run(main)